/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/compilation_provider.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "xla/stream_executor/cuda/compilation_options.h"
#include "xla/stream_executor/cuda/compilation_provider_test.h"
#include "xla/stream_executor/cuda/composite_compilation_provider.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/cuda/driver_compilation_provider.h"
#include "xla/stream_executor/cuda/nvjitlink_compilation_provider.h"
#include "xla/stream_executor/cuda/nvjitlink_support.h"
#include "xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h"
#include "xla/stream_executor/cuda/ptx_compiler_support.h"
#include "xla/stream_executor/cuda/subprocess_compilation.h"
#include "xla/stream_executor/cuda/subprocess_compilation_provider.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/threadpool.h"

namespace stream_executor::cuda {
using ::testing::_;
using ::testing::AnyOf;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Not;

void CompilationProviderTest::SetUp() {
#ifdef ABSL_HAVE_MEMORY_SANITIZER
  if (GetParam() == kNvJitLinkCompilationProviderName) {
    GTEST_SKIP() << "nvjitlink is a precompiled and not instrumented binary "
                    "library, so it's not compatible with MSAN.";
  }
  if (GetParam() == kNvptxcompilerCompilationProviderName) {
    GTEST_SKIP() << "nvptxcompiler is a precompiled and not instrumented "
                    "binary library, so it's not compatible with MSAN.";
  }
#endif

  absl::string_view provider = GetParam();

  if (!IsLibNvJitLinkSupported() &&
      (provider == kNvJitLinkCompilationProviderName ||
       provider ==
           kCompositeNvptxCompilerAndNvJitLinkCompilationProviderName)) {
    GTEST_SKIP() << "nvjitlink is not supported in this build.";
  }

  if (!IsLibNvPtxCompilerSupported() &&
      (provider == kNvptxcompilerCompilationProviderName ||
       provider ==
           kCompositeNvptxCompilerAndNvJitLinkCompilationProviderName)) {
    GTEST_SKIP() << "nvptxcompiler is not supported in this build.";
  }

  TF_ASSERT_OK_AND_ASSIGN(compilation_provider_,
                          CreateCompilationProvider(GetParam()));
}

// Instantiates the compilation provider identified by `name`.
//
// Returns NotFoundError for an unrecognized name. For the subprocess provider,
// any error from locating `ptxas`/`nvlink` is propagated.
absl::StatusOr<std::unique_ptr<CompilationProvider>>
CompilationProviderTest::CreateCompilationProvider(absl::string_view name) {
  // Out-of-process provider: needs paths to the ptxas and nvlink tools. The
  // "/does/not/exist" argument presumably forces FindCudaExecutable to use its
  // default search locations — verify against its implementation.
  if (name == kSubprocessCompilationProviderName) {
    TF_ASSIGN_OR_RETURN(auto ptxas_path,
                        FindCudaExecutable("ptxas", "/does/not/exist"));
    TF_ASSIGN_OR_RETURN(auto nvlink_path,
                        FindCudaExecutable("nvlink", "/does/not/exist"));
    return std::make_unique<SubprocessCompilationProvider>(ptxas_path,
                                                           nvlink_path);
  }

  // Library- and driver-backed providers are default-constructed.
  if (name == kNvJitLinkCompilationProviderName) {
    return std::make_unique<NvJitLinkCompilationProvider>();
  }
  if (name == kNvptxcompilerCompilationProviderName) {
    return std::make_unique<NvptxcompilerCompilationProvider>();
  }
  if (name == kDriverCompilationProviderName) {
    return std::make_unique<DriverCompilationProvider>();
  }

  // Composite: hands nvptxcompiler and then nvjitlink, in that order, to
  // CompositeCompilationProvider::Create.
  if (name == kCompositeNvptxCompilerAndNvJitLinkCompilationProviderName) {
    std::vector<std::unique_ptr<CompilationProvider>> delegates;
    delegates.emplace_back(
        std::make_unique<NvptxcompilerCompilationProvider>());
    delegates.emplace_back(std::make_unique<NvJitLinkCompilationProvider>());
    return CompositeCompilationProvider::Create(std::move(delegates));
  }

  return absl::NotFoundError(
      absl::StrCat("Unknown compilation provider: ", name));
}

// Every provider must report a non-empty name.
TEST_P(CompilationProviderTest, NameIsNotEmpty) {
  CompilationProvider* provider = compilation_provider();
  EXPECT_THAT(provider->name(), Not(IsEmpty()));
}

// PTX that defines the device function `_Z5magicv` (`int magic()`), which
// returns 42. The tests link it together with kDependentPtx, which only
// declares that function.
//
// Generated by the following command:
//
// echo "__device__ int magic() { return 42; }" |
//   nvcc -o - -rdc true --ptx --x cu -
//
constexpr const char kDependeePtx[] = R"(
.version 8.0
.target sm_80
.address_size 64

        // .globl       _Z5magicv

.visible .func  (.param .b32 func_retval0) _Z5magicv()
{
        .reg .b32       %r<2>;

        mov.u32         %r1, 42;
        st.param.b32    [func_retval0+0], %r1;
        ret;
})";

// PTX for a kernel `_Z6kernelPi` (`kernel(int*)`) that writes the result of
// calling `_Z5magicv` to its output pointer. `_Z5magicv` is only declared
// here (`.extern`). Compiling this module alone to a final assembly is
// therefore expected to fail with an undefined-reference error; it links
// successfully only together with kDependeePtx.
//
// Generated by the following command:
//
// echo "__device__ int magic(); __global__ void kernel(int* output) \
//   { *output = magic(); }" | nvcc -o - -rdc true --ptx --x cu -
//
constexpr const char kDependentPtx[] = R"(
.version 8.0
.target sm_80
.address_size 64

        // .globl       _Z6kernelPi
.extern .func  (.param .b32 func_retval0) _Z5magicv
()
;

.visible .entry _Z6kernelPi(
        .param .u64 _Z6kernelPi_param_0
)
{
        .reg .b32       %r<2>;
        .reg .b64       %rd<3>;

        ld.param.u64    %rd1, [_Z6kernelPi_param_0];
        cvta.to.global.u64      %rd2, %rd1;
        { // callseq 0, 0
        .reg .b32 temp_param_reg;
        .param .b32 retval0;
        call.uni (retval0),
        _Z5magicv,
        (
        );
        ld.param.b32    %r1, [retval0+0];
        } // callseq 0
        st.global.u32   [%rd2], %r1;
        ret;
})";

// Self-contained PTX (no external references) for a kernel that stores the
// constant 42 through its `int*` argument.
//
// Generated by the following command:
//
// echo "__global__ void kernel(int* output) { *output = 42; }" |
//   nvcc -o - -rdc true --ptx --x cu -
//
// The value written must be %r1, the register that was just loaded with 42.
// Storing any other (never-written) register would make the kernel write
// garbage instead of 42.
constexpr const char kStandalonePtx[] = R"(
.version 8.0
.target sm_80
.address_size 64

        // .globl       _Z6kernelPi

.visible .entry _Z6kernelPi (
        .param .u64 _Z6kernelPi_param_0
)
{
        .reg .b32       %r<2>;
        .reg .b64       %rd<3>;


        ld.param.u64    %rd1, [_Z6kernelPi_param_0];
        cvta.to.global.u64      %rd2, %rd1;
        mov.u32         %r1, 42;
        st.global.u32   [%rd2], %r1;
        ret;

})";

// PTX for a kernel that is meant to spill registers. It was generated from the
// CUDA snippet below, which copies 20 floats into a local array, rotates each
// group of four, and copies them back:
/*
#include <cstring>

__global__ void spilling_kernel(float* out) {
    constexpr int kSize = 20;
    float local[kSize];
    std::memcpy(local, out, sizeof(local));
    for(int i = 0; i < kSize; i += 4) {

        float a = local[i+0];
        float b = local[i+1];
        float c = local[i+2];
        float d = local[i+3];
        local[i+0] = b;
        local[i+1] = c;
        local[i+2] = d;
        local[i+3] = a;
    }
    std::memcpy(out, local, sizeof(local));
}
*/
// Generated by `cat kernel.cu | nvcc -o - --ptx --x cu -`, with the
// `.maxnreg 24` directive added by hand. In the PTX all 80 input bytes are
// loaded before any result is stored. Capping the kernel at 24 registers is
// intended to force ptxas to spill, which the `cancel_if_reg_spill` test
// relies on (that test also disables optimizations for the same reason).
constexpr const char kSpillingKernelPrefix[] = R"(
.version 8.0
.target sm_80
.address_size 64

	// .globl	_Z15spilling_kernelPf

.visible .entry _Z15spilling_kernelPf(
	.param .u64 _Z15spilling_kernelPf_param_0
) .maxnreg 24
{
	.reg .b32 	%r<201>;
	.reg .b64 	%rd<3>;
	ld.param.u64 	%rd1, [_Z15spilling_kernelPf_param_0];
	cvta.to.global.u64 	%rd2, %rd1;
	ld.global.u8 	%r1, [%rd2];
	ld.global.u8 	%r2, [%rd2+1];
	prmt.b32 	%r3, %r2, %r1, 30212;
	ld.global.u8 	%r4, [%rd2+2];
	ld.global.u8 	%r5, [%rd2+3];
	prmt.b32 	%r6, %r5, %r4, 30212;
	ld.global.u8 	%r7, [%rd2+4];
	ld.global.u8 	%r8, [%rd2+5];
	prmt.b32 	%r9, %r8, %r7, 30212;
	ld.global.u8 	%r10, [%rd2+6];
	ld.global.u8 	%r11, [%rd2+7];
	prmt.b32 	%r12, %r11, %r10, 30212;
	ld.global.u8 	%r13, [%rd2+8];
	ld.global.u8 	%r14, [%rd2+9];
	prmt.b32 	%r15, %r14, %r13, 30212;
	ld.global.u8 	%r16, [%rd2+10];
	ld.global.u8 	%r17, [%rd2+11];
	prmt.b32 	%r18, %r17, %r16, 30212;
	ld.global.u8 	%r19, [%rd2+12];
	ld.global.u8 	%r20, [%rd2+13];
	prmt.b32 	%r21, %r20, %r19, 30212;
	ld.global.u8 	%r22, [%rd2+14];
	ld.global.u8 	%r23, [%rd2+15];
	prmt.b32 	%r24, %r23, %r22, 30212;
	ld.global.u8 	%r25, [%rd2+16];
	ld.global.u8 	%r26, [%rd2+17];
	prmt.b32 	%r27, %r26, %r25, 30212;
	ld.global.u8 	%r28, [%rd2+18];
	ld.global.u8 	%r29, [%rd2+19];
	prmt.b32 	%r30, %r29, %r28, 30212;
	ld.global.u8 	%r31, [%rd2+20];
	ld.global.u8 	%r32, [%rd2+21];
	prmt.b32 	%r33, %r32, %r31, 30212;
	ld.global.u8 	%r34, [%rd2+22];
	ld.global.u8 	%r35, [%rd2+23];
	prmt.b32 	%r36, %r35, %r34, 30212;
	ld.global.u8 	%r37, [%rd2+24];
	ld.global.u8 	%r38, [%rd2+25];
	prmt.b32 	%r39, %r38, %r37, 30212;
	ld.global.u8 	%r40, [%rd2+26];
	ld.global.u8 	%r41, [%rd2+27];
	prmt.b32 	%r42, %r41, %r40, 30212;
	ld.global.u8 	%r43, [%rd2+28];
	ld.global.u8 	%r44, [%rd2+29];
	prmt.b32 	%r45, %r44, %r43, 30212;
	ld.global.u8 	%r46, [%rd2+30];
	ld.global.u8 	%r47, [%rd2+31];
	prmt.b32 	%r48, %r47, %r46, 30212;
	ld.global.u8 	%r49, [%rd2+32];
	ld.global.u8 	%r50, [%rd2+33];
	prmt.b32 	%r51, %r50, %r49, 30212;
	ld.global.u8 	%r52, [%rd2+34];
	ld.global.u8 	%r53, [%rd2+35];
	prmt.b32 	%r54, %r53, %r52, 30212;
	ld.global.u8 	%r55, [%rd2+36];
	ld.global.u8 	%r56, [%rd2+37];
	prmt.b32 	%r57, %r56, %r55, 30212;
	ld.global.u8 	%r58, [%rd2+38];
	ld.global.u8 	%r59, [%rd2+39];
	prmt.b32 	%r60, %r59, %r58, 30212;
	ld.global.u8 	%r61, [%rd2+40];
	ld.global.u8 	%r62, [%rd2+41];
	prmt.b32 	%r63, %r62, %r61, 30212;
	ld.global.u8 	%r64, [%rd2+42];
	ld.global.u8 	%r65, [%rd2+43];
	prmt.b32 	%r66, %r65, %r64, 30212;
	ld.global.u8 	%r67, [%rd2+44];
	ld.global.u8 	%r68, [%rd2+45];
	prmt.b32 	%r69, %r68, %r67, 30212;
	ld.global.u8 	%r70, [%rd2+46];
	ld.global.u8 	%r71, [%rd2+47];
	prmt.b32 	%r72, %r71, %r70, 30212;
	ld.global.u8 	%r73, [%rd2+48];
	ld.global.u8 	%r74, [%rd2+49];
	prmt.b32 	%r75, %r74, %r73, 30212;
	ld.global.u8 	%r76, [%rd2+50];
	ld.global.u8 	%r77, [%rd2+51];
	prmt.b32 	%r78, %r77, %r76, 30212;
	ld.global.u8 	%r79, [%rd2+52];
	ld.global.u8 	%r80, [%rd2+53];
	prmt.b32 	%r81, %r80, %r79, 30212;
	ld.global.u8 	%r82, [%rd2+54];
	ld.global.u8 	%r83, [%rd2+55];
	prmt.b32 	%r84, %r83, %r82, 30212;
	ld.global.u8 	%r85, [%rd2+56];
	ld.global.u8 	%r86, [%rd2+57];
	prmt.b32 	%r87, %r86, %r85, 30212;
	ld.global.u8 	%r88, [%rd2+58];
	ld.global.u8 	%r89, [%rd2+59];
	prmt.b32 	%r90, %r89, %r88, 30212;
	ld.global.u8 	%r91, [%rd2+60];
	ld.global.u8 	%r92, [%rd2+61];
	prmt.b32 	%r93, %r92, %r91, 30212;
	ld.global.u8 	%r94, [%rd2+62];
	ld.global.u8 	%r95, [%rd2+63];
	prmt.b32 	%r96, %r95, %r94, 30212;
	ld.global.u8 	%r97, [%rd2+64];
	ld.global.u8 	%r98, [%rd2+65];
	prmt.b32 	%r99, %r98, %r97, 30212;
	ld.global.u8 	%r100, [%rd2+66];
	ld.global.u8 	%r101, [%rd2+67];
	prmt.b32 	%r102, %r101, %r100, 30212;
	ld.global.u8 	%r103, [%rd2+68];
	ld.global.u8 	%r104, [%rd2+69];
	prmt.b32 	%r105, %r104, %r103, 30212;
	ld.global.u8 	%r106, [%rd2+70];
	ld.global.u8 	%r107, [%rd2+71];
	prmt.b32 	%r108, %r107, %r106, 30212;
	ld.global.u8 	%r109, [%rd2+72];
	ld.global.u8 	%r110, [%rd2+73];
	prmt.b32 	%r111, %r110, %r109, 30212;
	ld.global.u8 	%r112, [%rd2+74];
	ld.global.u8 	%r113, [%rd2+75];
	prmt.b32 	%r114, %r113, %r112, 30212;
	ld.global.u8 	%r115, [%rd2+76];
	ld.global.u8 	%r116, [%rd2+77];
	prmt.b32 	%r117, %r116, %r115, 30212;
	ld.global.u8 	%r118, [%rd2+78];
	ld.global.u8 	%r119, [%rd2+79];
	prmt.b32 	%r120, %r119, %r118, 30212;
	prmt.b32 	%r121, %r12, %r9, 4180;
	st.global.u8 	[%rd2], %r121;
	shr.u32 	%r122, %r121, 24;
	st.global.u8 	[%rd2+3], %r122;
	shr.u32 	%r123, %r121, 16;
	st.global.u8 	[%rd2+2], %r123;
	shr.u32 	%r124, %r121, 8;
	st.global.u8 	[%rd2+1], %r124;
	prmt.b32 	%r125, %r18, %r15, 4180;
	st.global.u8 	[%rd2+4], %r125;
	shr.u32 	%r126, %r125, 24;
	st.global.u8 	[%rd2+7], %r126;
	shr.u32 	%r127, %r125, 16;
	st.global.u8 	[%rd2+6], %r127;
	shr.u32 	%r128, %r125, 8;
	st.global.u8 	[%rd2+5], %r128;
	prmt.b32 	%r129, %r24, %r21, 4180;
	st.global.u8 	[%rd2+8], %r129;
	shr.u32 	%r130, %r129, 24;
	st.global.u8 	[%rd2+11], %r130;
	shr.u32 	%r131, %r129, 16;
	st.global.u8 	[%rd2+10], %r131;
	shr.u32 	%r132, %r129, 8;
	st.global.u8 	[%rd2+9], %r132;
	prmt.b32 	%r133, %r6, %r3, 4180;
	st.global.u8 	[%rd2+12], %r133;
	shr.u32 	%r134, %r133, 24;
	st.global.u8 	[%rd2+15], %r134;
	shr.u32 	%r135, %r133, 16;
	st.global.u8 	[%rd2+14], %r135;
	shr.u32 	%r136, %r133, 8;
	st.global.u8 	[%rd2+13], %r136;
	prmt.b32 	%r137, %r36, %r33, 4180;
	st.global.u8 	[%rd2+16], %r137;
	shr.u32 	%r138, %r137, 24;
	st.global.u8 	[%rd2+19], %r138;
	shr.u32 	%r139, %r137, 16;
	st.global.u8 	[%rd2+18], %r139;
	shr.u32 	%r140, %r137, 8;
	st.global.u8 	[%rd2+17], %r140;
	prmt.b32 	%r141, %r42, %r39, 4180;
	st.global.u8 	[%rd2+20], %r141;
	shr.u32 	%r142, %r141, 24;
	st.global.u8 	[%rd2+23], %r142;
	shr.u32 	%r143, %r141, 16;
	st.global.u8 	[%rd2+22], %r143;
	shr.u32 	%r144, %r141, 8;
	st.global.u8 	[%rd2+21], %r144;
	prmt.b32 	%r145, %r48, %r45, 4180;
	st.global.u8 	[%rd2+24], %r145;
	shr.u32 	%r146, %r145, 24;
	st.global.u8 	[%rd2+27], %r146;
	shr.u32 	%r147, %r145, 16;
	st.global.u8 	[%rd2+26], %r147;
	shr.u32 	%r148, %r145, 8;
	st.global.u8 	[%rd2+25], %r148;
	prmt.b32 	%r149, %r30, %r27, 4180;
	st.global.u8 	[%rd2+28], %r149;
	shr.u32 	%r150, %r149, 24;
	st.global.u8 	[%rd2+31], %r150;
	shr.u32 	%r151, %r149, 16;
	st.global.u8 	[%rd2+30], %r151;
	shr.u32 	%r152, %r149, 8;
	st.global.u8 	[%rd2+29], %r152;
	prmt.b32 	%r153, %r60, %r57, 4180;
	st.global.u8 	[%rd2+32], %r153;
	shr.u32 	%r154, %r153, 24;
	st.global.u8 	[%rd2+35], %r154;
	shr.u32 	%r155, %r153, 16;
	st.global.u8 	[%rd2+34], %r155;
	shr.u32 	%r156, %r153, 8;
	st.global.u8 	[%rd2+33], %r156;
	prmt.b32 	%r157, %r66, %r63, 4180;
	st.global.u8 	[%rd2+36], %r157;
	shr.u32 	%r158, %r157, 24;
	st.global.u8 	[%rd2+39], %r158;
	shr.u32 	%r159, %r157, 16;
	st.global.u8 	[%rd2+38], %r159;
	shr.u32 	%r160, %r157, 8;
	st.global.u8 	[%rd2+37], %r160;
	prmt.b32 	%r161, %r72, %r69, 4180;
	st.global.u8 	[%rd2+40], %r161;
	shr.u32 	%r162, %r161, 24;
	st.global.u8 	[%rd2+43], %r162;
	shr.u32 	%r163, %r161, 16;
	st.global.u8 	[%rd2+42], %r163;
	shr.u32 	%r164, %r161, 8;
	st.global.u8 	[%rd2+41], %r164;
	prmt.b32 	%r165, %r54, %r51, 4180;
	st.global.u8 	[%rd2+44], %r165;
	shr.u32 	%r166, %r165, 24;
	st.global.u8 	[%rd2+47], %r166;
	shr.u32 	%r167, %r165, 16;
	st.global.u8 	[%rd2+46], %r167;
	shr.u32 	%r168, %r165, 8;
	st.global.u8 	[%rd2+45], %r168;
	prmt.b32 	%r169, %r84, %r81, 4180;
	st.global.u8 	[%rd2+48], %r169;
	shr.u32 	%r170, %r169, 24;
	st.global.u8 	[%rd2+51], %r170;
	shr.u32 	%r171, %r169, 16;
	st.global.u8 	[%rd2+50], %r171;
	shr.u32 	%r172, %r169, 8;
	st.global.u8 	[%rd2+49], %r172;
	prmt.b32 	%r173, %r90, %r87, 4180;
	st.global.u8 	[%rd2+52], %r173;
	shr.u32 	%r174, %r173, 24;
	st.global.u8 	[%rd2+55], %r174;
	shr.u32 	%r175, %r173, 16;
	st.global.u8 	[%rd2+54], %r175;
	shr.u32 	%r176, %r173, 8;
	st.global.u8 	[%rd2+53], %r176;
	prmt.b32 	%r177, %r96, %r93, 4180;
	st.global.u8 	[%rd2+56], %r177;
	shr.u32 	%r178, %r177, 24;
	st.global.u8 	[%rd2+59], %r178;
	shr.u32 	%r179, %r177, 16;
	st.global.u8 	[%rd2+58], %r179;
	shr.u32 	%r180, %r177, 8;
	st.global.u8 	[%rd2+57], %r180;
	prmt.b32 	%r181, %r78, %r75, 4180;
	st.global.u8 	[%rd2+60], %r181;
	shr.u32 	%r182, %r181, 24;
	st.global.u8 	[%rd2+63], %r182;
	shr.u32 	%r183, %r181, 16;
	st.global.u8 	[%rd2+62], %r183;
	shr.u32 	%r184, %r181, 8;
	st.global.u8 	[%rd2+61], %r184;
	prmt.b32 	%r185, %r108, %r105, 4180;
	st.global.u8 	[%rd2+64], %r185;
	shr.u32 	%r186, %r185, 24;
	st.global.u8 	[%rd2+67], %r186;
	shr.u32 	%r187, %r185, 16;
	st.global.u8 	[%rd2+66], %r187;
	shr.u32 	%r188, %r185, 8;
	st.global.u8 	[%rd2+65], %r188;
	prmt.b32 	%r189, %r114, %r111, 4180;
	st.global.u8 	[%rd2+68], %r189;
	shr.u32 	%r190, %r189, 24;
	st.global.u8 	[%rd2+71], %r190;
	shr.u32 	%r191, %r189, 16;
	st.global.u8 	[%rd2+70], %r191;
	shr.u32 	%r192, %r189, 8;
	st.global.u8 	[%rd2+69], %r192;
	prmt.b32 	%r193, %r120, %r117, 4180;
	st.global.u8 	[%rd2+72], %r193;
	shr.u32 	%r194, %r193, 24;
	st.global.u8 	[%rd2+75], %r194;
	shr.u32 	%r195, %r193, 16;
	st.global.u8 	[%rd2+74], %r195;
	shr.u32 	%r196, %r193, 8;
	st.global.u8 	[%rd2+73], %r196;
	prmt.b32 	%r197, %r102, %r99, 4180;
	shr.u32 	%r198, %r197, 24;
	st.global.u8 	[%rd2+79], %r198;
	shr.u32 	%r199, %r197, 16;
	st.global.u8 	[%rd2+78], %r199;
	shr.u32 	%r200, %r197, 8;
	st.global.u8 	[%rd2+77], %r200;
	st.global.u8 	[%rd2+76], %r197;
	ret;
}
)";

// Target architecture used by most tests. The PTX fixtures in this file all
// declare `.target sm_80`, so compute capability 8.0 is used here.
constexpr stream_executor::CudaComputeCapability kDefaultComputeCapability{
    8, 0};

// Compiling self-contained PTX must yield a non-empty cubin. No compilation
// log is attached when the default options are used.
TEST_P(CompilationProviderTest, CompileStandaloneModuleSucceeds) {
  CompilationProvider* provider = compilation_provider();
  const CompilationOptions options{};

  TF_ASSERT_OK_AND_ASSIGN(
      Assembly assembly,
      provider->Compile(kDefaultComputeCapability, kStandalonePtx, options));

  EXPECT_FALSE(assembly.cubin.empty());
  EXPECT_EQ(assembly.compilation_log, std::nullopt);
}

// Setting `dump_compilation_log` must make Compile return a non-empty log
// alongside the assembly.
TEST_P(CompilationProviderTest,
       CompileStandaloneModuleDumpsCompilationLogWhenRequested) {
  CompilationOptions options;
  options.dump_compilation_log = true;
  TF_ASSERT_OK_AND_ASSIGN(
      Assembly module, compilation_provider()->Compile(
                           kDefaultComputeCapability, kStandalonePtx, options));
  // `Optional` is not among this file's using-declarations; qualify it instead
  // of relying on argument-dependent lookup to find ::testing::Optional.
  EXPECT_THAT(module.compilation_log, ::testing::Optional(Not(IsEmpty())));
}

// For providers that support relocatable output, compiling self-contained PTX
// must yield a non-empty cubin. No log is attached with default options.
TEST_P(CompilationProviderTest, CompileStandaloneRelocatableModuleSucceeds) {
  CompilationProvider* provider = compilation_provider();
  const bool supports_relocatable =
      provider->SupportsCompileToRelocatableModule();
  if (!supports_relocatable) {
    GTEST_SKIP();
  }

  const CompilationOptions options{};
  TF_ASSERT_OK_AND_ASSIGN(
      RelocatableModule relocatable,
      provider->CompileToRelocatableModule(kDefaultComputeCapability,
                                           kStandalonePtx, options));

  EXPECT_FALSE(relocatable.cubin.empty());
  EXPECT_EQ(relocatable.compilation_log, std::nullopt);
}

// Setting `dump_compilation_log` must make CompileToRelocatableModule return a
// non-empty log alongside the module (only for providers that support it).
TEST_P(CompilationProviderTest,
       CompileStandaloneRelocatableModuleDumpsCompilationLogWhenRequested) {
  if (!compilation_provider()->SupportsCompileToRelocatableModule()) {
    GTEST_SKIP();
  }

  CompilationOptions options;
  options.dump_compilation_log = true;
  TF_ASSERT_OK_AND_ASSIGN(
      RelocatableModule module,
      compilation_provider()->CompileToRelocatableModule(
          kDefaultComputeCapability, kStandalonePtx, options));
  // `Optional` is not among this file's using-declarations; qualify it instead
  // of relying on argument-dependent lookup to find ::testing::Optional.
  EXPECT_THAT(module.compilation_log, ::testing::Optional(Not(IsEmpty())));
}

// Providers without relocatable-module support must reject
// CompileToRelocatableModule with kUnavailable.
TEST_P(CompilationProviderTest,
       CompileToRelocatableModuleFailsWhenUnsupported) {
  CompilationProvider* provider = compilation_provider();
  const bool supports_relocatable =
      provider->SupportsCompileToRelocatableModule();
  if (supports_relocatable) {
    GTEST_SKIP();
  }

  const CompilationOptions options{};
  const auto result = provider->CompileToRelocatableModule(
      kDefaultComputeCapability, kStandalonePtx, options);
  EXPECT_THAT(result, absl_testing::StatusIs(absl::StatusCode::kUnavailable));
}

// CompileAndLink on a single self-contained PTX module must produce a
// non-empty cubin.
TEST_P(CompilationProviderTest, CompileAndLinkStandaloneModule) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileAndLink()) {
    GTEST_SKIP() << "Compilation provider doesn't support CompileAndLink";
  }

  const CompilationOptions options{};
  TF_ASSERT_OK_AND_ASSIGN(
      Assembly linked,
      provider->CompileAndLink(kDefaultComputeCapability,
                               {Ptx{kStandalonePtx}}, options));
  EXPECT_FALSE(linked.cubin.empty());
}

// Compiling kDependentPtx to a relocatable module must succeed even though
// `_Z5magicv` is only declared `.extern` in it; the unresolved symbol is left
// for a later link step.
TEST_P(CompilationProviderTest, CompileDependentRelocatableModuleSucceeds) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileToRelocatableModule()) {
    GTEST_SKIP();
  }

  const CompilationOptions options{};
  TF_ASSERT_OK_AND_ASSIGN(
      RelocatableModule relocatable,
      provider->CompileToRelocatableModule(kDefaultComputeCapability,
                                           kDependentPtx, options));
  EXPECT_FALSE(relocatable.cubin.empty());
}

// Compiling kDependentPtx straight to a final assembly must fail, because the
// `.extern` function `_Z5magicv` it calls has no definition. Providers word
// the error differently, so either known phrasing is accepted.
TEST_P(CompilationProviderTest,
       CompileDependentModuleFailsWithUndefinedReferenceError) {
#ifdef ABSL_HAVE_THREAD_SANITIZER
  // Compare against the shared provider-name constant (as SetUp does) instead
  // of a hard-coded literal, so this skip cannot drift out of sync with the
  // name the suite is parameterized with.
  if (GetParam() == kNvJitLinkCompilationProviderName) {
    GTEST_SKIP()
        << "nvjitlink fails with TSAN enabled due to some wrongly unlocked "
           "mutex. Note that this only happens when the compilation fails.";
  }
#endif

  CompilationOptions options;
  EXPECT_THAT(compilation_provider()->Compile(kDefaultComputeCapability,
                                              kDependentPtx, options),
              absl_testing::StatusIs(
                  _, AnyOf(HasSubstr("Undefined reference"),
                           HasSubstr("Unresolved extern function"))));
}

// Linking kDependentPtx on its own must fail, because nothing defines the
// `.extern` function `_Z5magicv`. Either known phrasing of the error is
// accepted.
TEST_P(CompilationProviderTest,
       CompileAndLinkDependentModuleFailsWithUndefinedReferenceError) {
  if (!compilation_provider()->SupportsCompileAndLink()) {
    GTEST_SKIP() << "Compilation provider doesn't support CompileAndLink";
  }
#ifdef ABSL_HAVE_THREAD_SANITIZER
  // Compare against the shared provider-name constant (as SetUp does) instead
  // of a hard-coded literal, so this skip cannot drift out of sync with the
  // name the suite is parameterized with.
  if (GetParam() == kNvJitLinkCompilationProviderName) {
    GTEST_SKIP()
        << "nvjitlink fails with TSAN enabled due to some wrongly unlocked "
           "mutex. Note that this only happens when the compilation fails.";
  }
#endif

  CompilationOptions options;
  EXPECT_THAT(compilation_provider()->CompileAndLink(
                  kDefaultComputeCapability, {Ptx{kDependentPtx}}, options),
              absl_testing::StatusIs(
                  _, AnyOf(HasSubstr("Undefined reference"),
                           HasSubstr("Unresolved extern function"))));
}

// Linking kDependentPtx together with kDependeePtx, which defines `_Z5magicv`,
// resolves the external reference and must produce a non-empty cubin.
TEST_P(CompilationProviderTest, CompileAndLinkMultipleModulesSucceeds) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileAndLink()) {
    GTEST_SKIP() << "Compilation provider doesn't support CompileAndLink";
  }

  const CompilationOptions options{};
  TF_ASSERT_OK_AND_ASSIGN(
      Assembly linked,
      provider->CompileAndLink(kDefaultComputeCapability,
                               {Ptx{kDependentPtx}, Ptx{kDependeePtx}},
                               options));
  EXPECT_FALSE(linked.cubin.empty());
}

// Two-phase flow: compile each PTX module to a relocatable module first, then
// hand the precompiled modules to CompileAndLink. Requires both capabilities.
TEST_P(CompilationProviderTest, CompileAndLaterLinkMultipleModulesSucceeds) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileToRelocatableModule()) {
    GTEST_SKIP()
        << "Compilation provider doesn't support CompileToRelocatableModule";
  }
  if (!provider->SupportsCompileAndLink()) {
    GTEST_SKIP() << "Compilation provider doesn't support CompileAndLink";
  }

  const CompilationOptions options{};
  TF_ASSERT_OK_AND_ASSIGN(
      RelocatableModule dependent,
      provider->CompileToRelocatableModule(kDefaultComputeCapability,
                                           kDependentPtx, options));
  TF_ASSERT_OK_AND_ASSIGN(
      RelocatableModule dependee,
      provider->CompileToRelocatableModule(kDefaultComputeCapability,
                                           kDependeePtx, options));

  TF_ASSERT_OK_AND_ASSIGN(
      Assembly linked,
      provider->CompileAndLink(kDefaultComputeCapability,
                               {std::move(dependent), std::move(dependee)},
                               options));
  EXPECT_FALSE(linked.cubin.empty());
}

// With `cancel_if_reg_spill` set, CompileAndLink on the register-capped
// spilling kernel must return kCancelled. With the flag cleared, the very same
// PTX must succeed, which shows the cancellation came from spilling and not
// from malformed PTX.
TEST_P(CompilationProviderTest, CancelsOnRegSpill) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileAndLink()) {
    GTEST_SKIP() << "Compilation provider doesn't support CompileAndLink";
  }
  if (GetParam() == kDriverCompilationProviderName) {
    GTEST_SKIP() << "Driver compilation doesn't support cancel_if_reg_spill";
  }

  // Optimizations must stay off; otherwise ptxas can eliminate the kernel's
  // trivial register traffic and no spill happens.
  CompilationOptions options;
  options.disable_optimizations = true;

  options.cancel_if_reg_spill = true;
  EXPECT_THAT(provider->CompileAndLink(kDefaultComputeCapability,
                                       {Ptx{kSpillingKernelPrefix}}, options),
              absl_testing::StatusIs(absl::StatusCode::kCancelled));

  options.cancel_if_reg_spill = false;
  EXPECT_THAT(provider->CompileAndLink(kDefaultComputeCapability,
                                       {Ptx{kSpillingKernelPrefix}}, options),
              absl_testing::IsOk());
}

// Compute capability 100.0 is not a valid target, so Compile must return an
// error status.
TEST_P(CompilationProviderTest,
       CompileFailsWhenInvalidArchitectureIsRequested) {
  const CudaComputeCapability bogus_capability{100, 0};
  const CompilationOptions options{};

  const auto result = compilation_provider()->Compile(
      bogus_capability, kStandalonePtx, options);
  EXPECT_THAT(result, Not(absl_testing::IsOk()));
}

// Requesting the invalid compute capability 100.0 must make
// CompileToRelocatableModule return an error status.
TEST_P(CompilationProviderTest,
       CompileToRelocatableModuleFailsWhenInvalidArchitectureIsRequested) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileToRelocatableModule()) {
    GTEST_SKIP()
        << "Compilation provider doesn't support CompileToRelocatableModule";
  }

  const CudaComputeCapability bogus_capability{100, 0};
  const CompilationOptions options{};
  const auto result = provider->CompileToRelocatableModule(
      bogus_capability, kStandalonePtx, options);
  EXPECT_THAT(result, Not(absl_testing::IsOk()));
}

// Requesting the invalid compute capability 100.0 must make CompileAndLink
// return an error status.
TEST_P(CompilationProviderTest,
       CompileAndLinkFailsWhenInvalidArchitectureIsRequested) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileAndLink()) {
    GTEST_SKIP() << "Compilation provider doesn't support CompileAndLink";
  }

  const CudaComputeCapability bogus_capability{100, 0};
  const CompilationOptions options{};
  const auto result = provider->CompileAndLink(
      bogus_capability, {Ptx{kStandalonePtx}}, options);
  EXPECT_THAT(result, Not(absl_testing::IsOk()));
}

// Concurrent Compile calls on one provider must each return exactly the
// assembly produced by a single reference compilation.
TEST_P(CompilationProviderTest, ParallelCompileReturnsSameResult) {
  CompilationProvider* provider = compilation_provider();
  TF_ASSERT_OK_AND_ASSIGN(
      Assembly reference_assembly,
      provider->Compile(kDefaultComputeCapability, kStandalonePtx,
                        CompilationOptions()));

  // One task per pool thread, on a 100-thread pool. A race will not
  // necessarily make an assertion fail, but the suite is also run under
  // ThreadSanitizer, which should flag broken locking in the provider. The
  // pool is declared after `reference_assembly`, so it is destroyed first.
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "test_pool", 100);

  const int num_tasks = pool.NumThreads();
  for (int task = 0; task < num_tasks; ++task) {
    pool.Schedule([provider, &reference_assembly]() {
      EXPECT_THAT(provider->Compile(kDefaultComputeCapability, kStandalonePtx,
                                    CompilationOptions()),
                  absl_testing::IsOkAndHolds(reference_assembly));
    });
  }
}

// Concurrent CompileToRelocatableModule calls on one provider must each return
// exactly the module produced by a single reference compilation.
TEST_P(CompilationProviderTest,
       ParallelCompileToRelocatableModuleReturnsSameResult) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileToRelocatableModule()) {
    GTEST_SKIP()
        << "Compilation provider doesn't support CompileToRelocatableModule";
  }

  TF_ASSERT_OK_AND_ASSIGN(
      RelocatableModule reference_module,
      provider->CompileToRelocatableModule(
          kDefaultComputeCapability, kStandalonePtx, CompilationOptions()));

  // One task per pool thread, on a 100-thread pool. A race will not
  // necessarily make an assertion fail, but the suite is also run under
  // ThreadSanitizer, which should flag broken locking in the provider. The
  // pool is declared after `reference_module`, so it is destroyed first.
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "test_pool", 100);

  const int num_tasks = pool.NumThreads();
  for (int task = 0; task < num_tasks; ++task) {
    pool.Schedule([provider, &reference_module]() {
      EXPECT_THAT(
          provider->CompileToRelocatableModule(
              kDefaultComputeCapability, kStandalonePtx, CompilationOptions()),
          absl_testing::IsOkAndHolds(reference_module));
    });
  }
}

// Concurrent CompileAndLink calls on one provider must each return exactly the
// assembly produced by a single reference link.
TEST_P(CompilationProviderTest, ParallelCompileAndLinkReturnsSameResult) {
  CompilationProvider* provider = compilation_provider();
  if (!provider->SupportsCompileAndLink()) {
    GTEST_SKIP() << "Compilation provider doesn't support CompileAndLink";
  }

  TF_ASSERT_OK_AND_ASSIGN(
      Assembly reference_assembly,
      provider->CompileAndLink(kDefaultComputeCapability,
                               {Ptx{kStandalonePtx}}, CompilationOptions()));

  // One task per pool thread, on a 100-thread pool. A race will not
  // necessarily make an assertion fail, but the suite is also run under
  // ThreadSanitizer, which should flag broken locking in the provider. The
  // pool is declared after `reference_assembly`, so it is destroyed first.
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "test_pool", 100);

  const int num_tasks = pool.NumThreads();
  for (int task = 0; task < num_tasks; ++task) {
    pool.Schedule([provider, &reference_assembly]() {
      EXPECT_THAT(provider->CompileAndLink(kDefaultComputeCapability,
                                           {Ptx{kStandalonePtx}},
                                           CompilationOptions()),
                  absl_testing::IsOkAndHolds(reference_assembly));
    });
  }
}

// Subprocess, nvptxcompiler, nvjitlink and composite providers must report a
// plausible latest PTX ISA version. Every other provider must return
// kUnimplemented.
TEST_P(CompilationProviderTest,
       QueryLatestPtxIsaVersionReturnsAValidPtxIsaVersion) {
  CompilationProvider* provider = compilation_provider();
  const bool can_report_ptx_isa_version =
      dynamic_cast<SubprocessCompilationProvider*>(provider) != nullptr ||
      dynamic_cast<NvptxcompilerCompilationProvider*>(provider) != nullptr ||
      dynamic_cast<NvJitLinkCompilationProvider*>(provider) != nullptr ||
      dynamic_cast<CompositeCompilationProvider*>(provider) != nullptr;

  if (!can_report_ptx_isa_version) {
    EXPECT_THAT(provider->GetLatestPtxIsaVersion(),
                absl_testing::StatusIs(absl::StatusCode::kUnimplemented));
    return;
  }

  TF_ASSERT_OK_AND_ASSIGN(int latest_ptx_isa_version,
                          provider->GetLatestPtxIsaVersion());
  // Versions appear to be encoded as major * 10 + minor (80 == PTX 8.0). The
  // upper bound needs raising once PTX 20.0 exists.
  EXPECT_GE(latest_ptx_isa_version, 80);
  EXPECT_LE(latest_ptx_isa_version, 200);
}

// No INSTANTIATE_TEST_SUITE_P for CompilationProviderTest appears in this
// file. This macro stops googletest from reporting the parameterized suite as
// uninstantiated in binaries that link these test bodies without instantiating
// them. Presumably the instantiations live in provider-specific test targets;
// verify.
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationProviderTest);

}  // namespace stream_executor::cuda
